#!/usr/bin/env python3
"""Generate CasADi-based quadrotor flatness C sources.

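The resulting `.c` and `.h` files remove the runtime dependency on CasADi by
embedding the quadrotor model and its derivatives directly into the extension
build. A companion `quadrotor_flatness_config.hpp` header records the physical
parameters that were baked into the generated code.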
"""

from __future__ import annotations

import argparse
import math
import shutil
import tempfile
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Sequence

import casadi as ca
import yaml

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# YAML file read when no --config option is given on the command line.
DEFAULT_CONFIG_PATH = PROJECT_ROOT.joinpath("config", "casadi_quadrotor_flatness.yaml")
# Destination of the generated C translation unit.
GENERATED_SRC = PROJECT_ROOT.joinpath("src", "minco_trajectory", "src", "casadi_generated")
# Destination of the generated C header and the embedded-config C++ header.
GENERATED_INCLUDE = PROJECT_ROOT.joinpath(
    "src", "minco_trajectory", "include", "casadi_generated"
)


@dataclass(frozen=True)
class GenerationArtifacts:
    """Pair of CasADi functions built by ``_build_functions`` and passed to code generation."""

    # Forward flatness map: (velocity, acceleration, jerk, yaw, yaw_rate) ->
    # 8-vector [thrust, quat0..quat3, omg0..omg2].
    forward: ca.Function
    # Gradient (vector-Jacobian product) of the forward map, plus pass-through
    # position/velocity gradients.
    backward: ca.Function


@dataclass(frozen=True)
class QuadrotorFlatnessConfig:
    """Physical parameters embedded as constants into the generated flatness code.

    Values are validated on construction, so an invalid configuration fails here
    instead of silently producing generated code that divides by zero or contains
    non-finite literals.

    Raises:
        ValueError: if any parameter is NaN or infinite, ``mass`` is not positive,
            ``speed_smooth`` is negative, or ``speed_smooth`` is zero while
            ``parasitic_drag`` is non-zero.
    """

    mass: float = 1.0
    gravity: float = 9.81
    horizontal_drag: float = 0.0
    vertical_drag: float = 0.0
    parasitic_drag: float = 0.0
    speed_smooth: float = 1.0e-3

    def __post_init__(self) -> None:
        for spec in fields(self):
            value = getattr(self, spec.name)
            # Non-finite values would be written as `nan`/`inf`, which are not C++ literals.
            if not math.isfinite(value):
                raise ValueError(
                    f"Flatness parameter {spec.name!r} must be finite, got {value!r}"
                )
        # The model divides by mass (horizontal_drag / mass).
        if self.mass <= 0.0:
            raise ValueError(f"Flatness parameter 'mass' must be positive, got {self.mass!r}")
        # speed_smooth is added to |v|^2 under a square root.
        if self.speed_smooth < 0.0:
            raise ValueError(
                f"Flatness parameter 'speed_smooth' must be non-negative, got {self.speed_smooth!r}"
            )
        # With parasitic drag, the model divides by sqrt(|v|^2 + speed_smooth);
        # a zero offset makes that singular at zero velocity.
        if self.speed_smooth == 0.0 and self.parasitic_drag != 0.0:
            raise ValueError(
                "Flatness parameter 'speed_smooth' must be positive when 'parasitic_drag' is non-zero"
            )

    @staticmethod
    def _read_float(node: dict, key: str, default: float) -> float:
        """Return ``node[key]`` converted to float, or *default* when *key* is absent.

        Raises:
            ValueError: if the value is ``None`` (an empty YAML entry), a boolean,
                or anything ``float()`` cannot convert.
        """
        if key not in node:
            return default
        value = node[key]
        message = f"Flatness parameter {key!r} must be a number, got {value!r}"
        # bool is an int subclass; `mass: yes` should not silently become 1.0.
        if value is None or isinstance(value, bool):
            raise ValueError(message)
        try:
            return float(value)
        except (TypeError, ValueError) as err:
            raise ValueError(message) from err

    @classmethod
    def from_yaml(cls, path: Path) -> "QuadrotorFlatnessConfig":
        """Load parameters from the YAML file at *path*.

        Parameters are read from a top-level ``flatness`` mapping when present,
        otherwise from the document root. Missing keys fall back to the class
        defaults; other keys are ignored.

        Raises:
            ValueError: if the document or the parameter node is not a mapping, a
                parameter is not a number, or the values fail validation.
        """
        raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        if not isinstance(raw, dict):
            raise ValueError(f"Flatness config must be a mapping, got {type(raw)!r}")
        node = raw.get("flatness", raw)
        if not isinstance(node, dict):
            raise ValueError("Flatness node must be a mapping of parameter names")
        defaults = cls()
        read = cls._read_float
        return cls(
            mass=read(node, "mass", defaults.mass),
            gravity=read(node, "gravity", defaults.gravity),
            horizontal_drag=read(node, "horizontal_drag", defaults.horizontal_drag),
            vertical_drag=read(node, "vertical_drag", defaults.vertical_drag),
            parasitic_drag=read(node, "parasitic_drag", defaults.parasitic_drag),
            speed_smooth=read(node, "speed_smooth", defaults.speed_smooth),
        )


def _build_functions(config: QuadrotorFlatnessConfig) -> GenerationArtifacts:
    """Build the symbolic forward flatness map and its gradient as CasADi functions.

    All values in *config* are embedded as numeric ``SX`` constants. The generated
    functions therefore take only the symbolic inputs listed below.

    ``casadi_quadrotor_flatness_forward``
        inputs:  velocity(3), acceleration(3), jerk(3), yaw, yaw_rate
        output:  8-vector [thrust, quat0, quat1, quat2, quat3, omg0, omg1, omg2]

    ``casadi_quadrotor_flatness_backward``
        inputs:  the five forward inputs, then position_gradient(3),
                 velocity_gradient(3), thrust_gradient, quaternion_gradient(4),
                 angular_velocity_gradient(3)
        output:  14-vector [position_gradient (copied unchanged),
                 d/dvelocity + velocity_gradient, d/dacceleration, d/djerk,
                 d/dyaw, d/dyaw_rate], where the derivatives are
                 J^T * [thrust_gradient; quaternion_gradient; angular_velocity_gradient]
                 and J is the 8x11 Jacobian of the forward output.

    The generated code is singular in two cases:
    - zu (below) is the zero vector, because of the division by |zu|;
    - the normalized z-axis has z2 == -1, because of the divisions by
      sqrt(2 * (1 + z2)) and by (1 + z2).
    """
    # Symbolic inputs of the forward map.
    vel = ca.SX.sym("velocity", 3)
    acc = ca.SX.sym("acceleration", 3)
    jer = ca.SX.sym("jerk", 3)
    yaw = ca.SX.sym("yaw")
    yaw_rate = ca.SX.sym("yaw_rate")

    # Config values become numeric constants in the expression graph.
    mass = ca.SX(config.mass)
    grav = ca.SX(config.gravity)
    dh = ca.SX(config.horizontal_drag)
    dv = ca.SX(config.vertical_drag)
    cp = ca.SX(config.parasitic_drag)
    veps = ca.SX(config.speed_smooth)

    v0, v1, v2 = vel[0], vel[1], vel[2]
    a0, a1, a2 = acc[0], acc[1], acc[2]
    j0, j1, j2 = jer[0], jer[1], jer[2]

    # Smoothed speed sqrt(|v|^2 + speed_smooth). The offset keeps it, and the
    # division by it in dw_term below, away from zero at v = 0.
    cp_term = ca.sqrt(v0 * v0 + v1 * v1 + v2 * v2 + veps)
    # w = (1 + parasitic_drag * smoothed_speed) * v
    w_term = 1.0 + cp * cp_term
    w0 = w_term * v0
    w1 = w_term * v1
    w2 = w_term * v2
    dh_over_m = dh / mass

    # zu = a + (horizontal_drag / mass) * w + gravity * e_z (un-normalized z-axis).
    zu0 = a0 + dh_over_m * w0
    zu1 = a1 + dh_over_m * w1
    zu2 = a2 + dh_over_m * w2 + grav

    zu_sqr0 = zu0 * zu0
    zu_sqr1 = zu1 * zu1
    zu_sqr2 = zu2 * zu2
    zu01 = zu0 * zu1
    zu12 = zu1 * zu2
    zu02 = zu0 * zu2
    zu_sqr_norm = zu_sqr0 + zu_sqr1 + zu_sqr2
    zu_norm = ca.sqrt(zu_sqr_norm)

    # z = zu / |zu|
    z0 = zu0 / zu_norm
    z1 = zu1 / zu_norm
    z2 = zu2 / zu_norm

    # Entries of the symmetric matrix (|zu|^2 * I - zu * zu^T) / |zu|^3, which is
    # the Jacobian of the normalization zu -> zu / |zu|.
    ng_den = zu_sqr_norm * zu_norm
    ng00 = (zu_sqr1 + zu_sqr2) / ng_den
    ng01 = -zu01 / ng_den
    ng02 = -zu02 / ng_den
    ng11 = (zu_sqr0 + zu_sqr2) / ng_den
    ng12 = -zu12 / ng_den
    ng22 = (zu_sqr0 + zu_sqr1) / ng_den

    # dw = dw/dt = w_term * a + parasitic_drag * (v . a) / smoothed_speed * v
    v_dot_a = v0 * a0 + v1 * a1 + v2 * a2
    dw_term = cp * v_dot_a / cp_term
    dw0 = w_term * a0 + dw_term * v0
    dw1 = w_term * a1 + dw_term * v1
    dw2 = w_term * a2 + dw_term * v2

    # d(zu)/dt = j + (horizontal_drag / mass) * dw
    dz_term0 = j0 + dh_over_m * dw0
    dz_term1 = j1 + dh_over_m * dw1
    dz_term2 = j2 + dh_over_m * dw2

    # dz/dt = normalization Jacobian (ng) applied to d(zu)/dt
    dz0 = ng00 * dz_term0 + ng01 * dz_term1 + ng02 * dz_term2
    dz1 = ng01 * dz_term0 + ng11 * dz_term1 + ng12 * dz_term2
    dz2 = ng02 * dz_term0 + ng12 * dz_term1 + ng22 * dz_term2

    # thrust = z . (mass * (a + gravity * e_z) + vertical_drag * w)
    f_term0 = mass * a0 + dv * w0
    f_term1 = mass * a1 + dv * w1
    f_term2 = mass * (a2 + grav) + dv * w2

    thrust = z0 * f_term0 + z1 * f_term1 + z2 * f_term2

    # Tilt quaternion built from z alone; it is the identity for z = (0, 0, 1).
    # Singular when z2 == -1.
    tilt_den = ca.sqrt(2.0 * (1.0 + z2))
    tilt0 = 0.5 * tilt_den
    tilt1 = -z1 / tilt_den
    tilt2 = z0 / tilt_den

    c_half_psi = ca.cos(0.5 * yaw)
    s_half_psi = ca.sin(0.5 * yaw)

    # These four lines equal the Hamilton product
    # (tilt0, tilt1, tilt2, 0) * (cos(yaw/2), 0, 0, sin(yaw/2)),
    # reading quat0 as the scalar part.
    quat0 = tilt0 * c_half_psi
    quat1 = tilt1 * c_half_psi + tilt2 * s_half_psi
    quat2 = tilt2 * c_half_psi - tilt1 * s_half_psi
    quat3 = tilt0 * s_half_psi

    # Angular velocity from dz, yaw and yaw_rate. omg_den = 1 + z2 again
    # vanishes when z2 == -1.
    c_psi = ca.cos(yaw)
    s_psi = ca.sin(yaw)
    omg_den = z2 + 1.0
    omg_term = dz2 / omg_den

    omg0 = dz0 * s_psi - dz1 * c_psi - (z0 * s_psi - z1 * c_psi) * omg_term
    omg1 = dz0 * c_psi + dz1 * s_psi - (z0 * c_psi + z1 * s_psi) * omg_term
    omg2 = (z1 * dz0 - z0 * dz1) / omg_den + yaw_rate

    forward_out = ca.vertcat(thrust, quat0, quat1, quat2, quat3, omg0, omg1, omg2)

    forward = ca.Function(
        "casadi_quadrotor_flatness_forward",
        [vel, acc, jer, yaw, yaw_rate],
        [forward_out],
        ["velocity", "acceleration", "jerk", "yaw", "yaw_rate"],
        ["flatness_outputs"],
    )

    # Incoming gradients for the backward function.
    pos_grad = ca.SX.sym("position_gradient", 3)
    vel_grad = ca.SX.sym("velocity_gradient", 3)
    thr_grad = ca.SX.sym("thrust_gradient")
    quat_grad = ca.SX.sym("quaternion_gradient", 4)
    omg_grad = ca.SX.sym("angular_velocity_gradient", 3)

    # 8x11 Jacobian of forward_out w.r.t. the stacked inputs
    # [vel(3), acc(3), jer(3), yaw, yaw_rate].
    state = ca.vertcat(vel, acc, jer, yaw, yaw_rate)
    jac_forward = ca.jacobian(forward_out, state)

    output_grad = ca.vertcat(thr_grad, quat_grad, omg_grad)
    # Only velocity_gradient is added directly to an input slot. The acc, jerk,
    # yaw and yaw_rate slots receive zero here.
    input_grad = ca.vertcat(
        vel_grad,
        ca.SX.zeros(3),
        ca.SX.zeros(3),
        ca.SX.zeros(1),
        ca.SX.zeros(1),
    )

    total_grad = ca.mtimes(jac_forward.T, output_grad) + input_grad

    # Layout: [pos_grad(3) unchanged, d/dvel(3), d/dacc(3), d/djerk(3), d/dyaw, d/dyaw_rate].
    backward_out = ca.vertcat(
        pos_grad,
        total_grad[0:3],
        total_grad[3:6],
        total_grad[6:9],
        total_grad[9],
        total_grad[10],
    )

    backward = ca.Function(
        "casadi_quadrotor_flatness_backward",
        [
            vel,
            acc,
            jer,
            yaw,
            yaw_rate,
            pos_grad,
            vel_grad,
            thr_grad,
            quat_grad,
            omg_grad,
        ],
        [backward_out],
        [
            "velocity",
            "acceleration",
            "jerk",
            "yaw",
            "yaw_rate",
            "position_gradient",
            "velocity_gradient",
            "thrust_gradient",
            "quaternion_gradient",
            "angular_velocity_gradient",
        ],
        ["flatness_backward_outputs"],
    )

    return GenerationArtifacts(forward=forward, backward=backward)


def _write_config_header(config: QuadrotorFlatnessConfig) -> None:
    GENERATED_INCLUDE.mkdir(parents=True, exist_ok=True)
    header_path = GENERATED_INCLUDE / "quadrotor_flatness_config.hpp"
    template = """#pragma once\n\n#include \"flatness.hpp\"\n\nnamespace minco::flatness::casadi_generated\n{{\n\ninline constexpr DefaultConfig kEmbeddedConfig{{\n    .mass            = {mass:.17g},\n    .gravity         = {gravity:.17g},\n    .horizontal_drag = {horizontal_drag:.17g},\n    .vertical_drag   = {vertical_drag:.17g},\n    .parasitic_drag  = {parasitic_drag:.17g},\n    .speed_smooth    = {speed_smooth:.17g},\n}};\n\n}}  // namespace minco::flatness::casadi_generated\n\n"""

    header_path.write_text(
        template.format(
            mass=config.mass,
            gravity=config.gravity,
            horizontal_drag=config.horizontal_drag,
            vertical_drag=config.vertical_drag,
            parasitic_drag=config.parasitic_drag,
            speed_smooth=config.speed_smooth,
        ),
        encoding="utf-8",
    )


def _generate_casadi_sources(functions: GenerationArtifacts) -> None:
    """Generate ``quadrotor_flatness.c``/``.h`` for *functions* and install them.

    CasADi writes both files into a scratch directory. The ``.c`` file is then
    moved to ``GENERATED_SRC`` and the ``.h`` file to ``GENERATED_INCLUDE``,
    replacing any previous versions.
    """
    # Create both destinations here rather than relying on another step having done so.
    GENERATED_SRC.mkdir(parents=True, exist_ok=True)
    GENERATED_INCLUDE.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        codegen = ca.CodeGenerator("quadrotor_flatness.c", {"with_header": True})
        codegen.add(functions.forward)
        codegen.add(functions.backward)
        # CasADi prepends this prefix to the file names, hence the trailing separator.
        codegen.generate(str(tmp_path) + "/")

        # shutil.move rather than Path.replace: the system temp dir is often on a
        # different filesystem (e.g. tmpfs /tmp, container overlays), where a
        # plain rename fails with EXDEV. shutil.move falls back to copy + unlink.
        shutil.move(
            str(tmp_path / "quadrotor_flatness.c"),
            str(GENERATED_SRC / "quadrotor_flatness.c"),
        )
        shutil.move(
            str(tmp_path / "quadrotor_flatness.h"),
            str(GENERATED_INCLUDE / "quadrotor_flatness.h"),
        )


def _parse_args(argv: Sequence[str] | None) -> argparse.Namespace:
    """Parse the generator's command-line options.

    Args:
        argv: Arguments without the program name. ``None`` makes argparse read
            ``sys.argv``.

    Returns:
        A namespace whose ``config`` attribute is a ``Path``. It defaults to
        ``DEFAULT_CONFIG_PATH``.
    """
    cli = argparse.ArgumentParser(description="Generate CasADi flatness sources")
    cli.add_argument(
        "--config",
        help="Path to the YAML config consumed during CasADi code generation",
        default=DEFAULT_CONFIG_PATH,
        type=Path,
    )
    namespace = cli.parse_args(argv)
    return namespace


def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point: load the config and regenerate all outputs.

    The config is loaded and the CasADi functions are built before any file is
    written. A failure in either step therefore leaves the previously generated
    files untouched instead of pairing a new config header with stale sources.

    Args:
        argv: Arguments without the program name. ``None`` reads ``sys.argv``.

    Returns:
        Process exit status. It is always 0; errors propagate as exceptions.
    """
    args = _parse_args(argv)
    config = QuadrotorFlatnessConfig.from_yaml(args.config)
    functions = _build_functions(config)
    _write_config_header(config)
    _generate_casadi_sources(functions)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)
